Use MaxText max_segments_per_seq config variable to control Grain batch packing #2774
Conversation
yeandy left a comment
We may need to add max_sequences_per_bin=config.max_segments_per_seq in make_hf_eval_iterator too
Done, thanks for the tip.
Can you also update here: https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/types.py#L882? And apply max_sequences_per_bin here: https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/input_pipeline/_grain_data_processing.py#L247
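For illustration only, the config-schema side of that suggestion might look roughly like the sketch below. This is not the actual contents of types.py; the dataclass, default value, and comment are assumptions, and only the field name max_segments_per_seq and the Grain kwarg max_sequences_per_bin come from this thread.

```python
# Hypothetical sketch of a config-schema entry; not real MaxText code.
from dataclasses import dataclass

@dataclass
class DataLoadingConfig:
  # Upper bound on how many segments Grain may pack into one sequence; passed
  # to Grain as max_sequences_per_bin. With THD-format packed data this must
  # match the segment count the model is JIT-compiled to handle.
  max_segments_per_seq: int = -1  # -1 meaning "no explicit cap" is an assumption
```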
Force-pushed from 9f3b716 to ecae099
Done.
aireenmei left a comment
Thanks!
Force-pushed from 1c5b56e to d107901
@aireenmei I apologize, but I needed to make some updates for the code quality checks (they didn't run before your approval). Can you please take another look? These were minor edits.
Could you take a look to see if this error is related: https://github.com/AI-Hypercomputer/maxtext/actions/runs/20354898076/job/58493556179?pr=2774
I've asked @SurbhiJainUSC to also take a look. @gabeweisz, you will need to squash your commits into one before merging.
Force-pushed from d107901 to f7971f2
@aireenmei thanks for the feedback - I have squashed the commits.
Looks like I still need an approval from a maintainer. How can I get one?
Description
When using THD-format packed data with TransformerEngine, the user must specify the maximum number of segments that can be packed into a sequence at JAX JIT time. If Grain packs more segments than allowed, this can cause crashes or data corruption.
We previously updated Grain to allow limiting the number of segments packed into a sequence, and this PR takes the appropriate value from the MaxText configuration (max_segments_per_seq) and passes it to Grain.
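As a minimal sketch of the threading described above (assuming a Grain packing step that accepts a max_sequences_per_bin argument, per the upstream Grain change mentioned here): the helper and the guarding flag below are hypothetical, not the actual MaxText pipeline code; only the names max_sequences_per_bin and config.max_segments_per_seq are taken from this PR.

```python
# Minimal sketch, not the actual implementation.
def packing_options(config):
  """Builds keyword arguments for Grain's example-packing step."""
  opts = {}
  if config.packing:  # hypothetical flag guarding example packing
    # Cap the number of segments packed into one sequence so that THD-format
    # TransformerEngine kernels, compiled at JAX JIT time for a fixed segment
    # count, never receive more segments than they can handle.
    opts["max_sequences_per_bin"] = config.max_segments_per_seq
  return opts
```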
Tests
We have had this fix in place in our AMD fork of MaxText for some time, but we needed to get the Grain fix upstreamed before creating this PR.
We have tested this fix extensively internally and have customers using it in production.
MaxText does not currently have any tests that use packed batches, but I can create some if needed.
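Such a test might look roughly like the sketch below; the iterator argument and the inputs_segmentation field name are assumptions about the MaxText batch layout, not code from this PR.

```python
# Hypothetical test sketch: verifies that no packed sequence exceeds the
# configured segment cap. Batch field names are assumptions.
import numpy as np

def check_packed_batch_segment_cap(config, train_iterator):
  batch = next(train_iterator)
  # Segment IDs are assumed 1-based within each packed row, with 0 as padding,
  # so the per-row maximum equals the number of segments packed into that row.
  segments_per_row = np.max(batch["inputs_segmentation"], axis=-1)
  assert np.all(segments_per_row <= config.max_segments_per_seq)
```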
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.